{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction to Text Classification\n",
    "\n",
    "*Copyright (c) 2022 Institute for Quantum Computing, Baidu Inc. All Rights Reserved.*\n",
    "\n",
    "Natural language processing (NLP) is a machine learning technology that gives computers the ability to interpret, manipulate, and comprehend human language. Organizations today have large volumes of voice and text data from various communication channels like emails, text messages, social media newsfeeds, video, audio, and more. They use NLP software to automatically process this data, analyze the intent or sentiment in the message, and respond in real time to human communication.\n",
    "\n",
    "The text classification task is one of the fundamental tasks in NLP. It predicts the category of the input text. It is the technique behind the applications such as news headline classification and sentiment analysis.\n",
    "\n",
    "Here, we use news headline classification as an example to demonstrate the power of Quantum Machine Learning (QML) for the text classification problems.\n",
    "\n",
    "We use two types of news headlines, real estate industry and automobile industry, as the dataset to classify them. In this dataset, the training set consists of 400 text data and the test set consists of 100 text data. A sample of the data is as follows.\n",
    "\n",
    "- 奔驰GLS怎么样？\n",
    "- 如何评价襄阳房价？\n",
    "- 南宁的城建怎么样？\n",
    "- 保时捷为什么这么贵"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## News Headline Classification Using QSANN Model\n",
    "\n",
    "### Introduction to the QSANN Model\n",
    "\n",
    "Quantum Self-Attention Neural Networks (QSANN) is a hybrid quantum-classical algorithm in a supervised learning framework. It uses the parameterized quantum circuit (PQC) to encode features on text data, the self-attention mechanism for feature extraction, and finally a fully connected neural network to process the classification results.\n",
    "\n",
    "In summary, the general principle of QSANN is as follows.\n",
    "\n",
    "1. Each word of the input text is mapped into a corresponding parameterized quantum circuit. The quantum state obtained from the evolution of this circuit is the corresponding feature representation of this word.\n",
    "2. Use the self-attention mechanism to process the quantum state and obtain the processed feature representation.\n",
    "3. Use the fully connected neural network to process the obtained features and get the predicted classification results.\n",
    "\n",
    "### Workflow\n",
    "\n",
    "QSANN is a learning model. So we need to first train the model using the dataset. After the training converges, we get a trained model which can classify the data corresponding to the task. Thus, the workflow is as follows.\n",
    "\n",
    "1. Prepare the dataset.\n",
    "2. Train with the dataset to obtain a trained model.\n",
    "3. Use the model to predict the input text and get the prediction results."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to Use\n",
    "\n",
    "### Predict Using the Model\n",
    "\n",
    "Here, we have presented a trained model that can be directly used for news headline classification in the real estate industry and the automobile industry. Just make the corresponding configuration in the configuration file `example.toml` and enter the command `python qsann_classification.py --config example.toml` to use the trained QSANN model for predicting the input text.\n",
    "\n",
    "### Online Demo\n",
    "\n",
    "Here, we give a version of the demo that can be tested online. First define the contents of the configuration file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_toml = r\"\"\"\n",
    "# The overall configuration file of the model.\n",
    "# Enter the current task, which can be 'train' or 'test', representing training and prediction respectively. Here we use test, indicating that we want to make a prediction.\n",
    "task = 'test'\n",
    "# The text to be tested.\n",
    "text = '奔驰GLS怎么样？'\n",
    "# The path of the trained model, which will be loaded.\n",
    "model_path = 'qsann.pdparams'\n",
    "# The path of the vocabulary file in the dataset.\n",
    "vocab_path = 'headlines500/vocab.txt'\n",
    "# The number of qubits which the quantum circuit contains.\n",
    "num_qubits = 6\n",
    "# The number of the self-attention layers.\n",
    "num_layers = 1\n",
    "# The depth of the embedding circuit.\n",
    "depth_ebd = 1\n",
    "# The depth of the query circuit.\n",
    "depth_query = 1\n",
    "# The depth of the key circuit.\n",
    "depth_key = 1\n",
    "# The depth of the value circuit.\n",
    "depth_value = 1\n",
    "# The classes of input text to be predicted.\n",
    "classes = ['房地产行业', '汽车行业']\n",
    "\"\"\"\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next is the code for the prediction section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The input text is 奔驰GLS怎么样？.\n",
      "The prediction of the model is 汽车行业.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'\n",
    "\n",
    "import toml\n",
    "from paddle_quantum.qml.qsann import train, inference\n",
    "\n",
    "config = toml.loads(test_toml)\n",
    "task = config.pop('task')\n",
    "if task == 'train':\n",
    "    train(**config)\n",
    "elif task == 'test':\n",
    "    prediction = inference(**config)\n",
    "    text = config['text']\n",
    "    print(f'The input text is {text}.')\n",
    "    print(f'The prediction of the model is {prediction}.')\n",
    "else:\n",
    "    raise ValueError(\"Unknown task, it can be train or test.\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we only need to modify the content of the text in the configuration file, and then run the entire code to quickly test other texts."
   ]
  },
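  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To try several headlines in one run, one option is to build one configuration per text. This is a minimal sketch, assuming the configuration is the plain dictionary returned by `toml.loads` with `task` already removed. The helper `configs_for_texts` and the shortened `base_config` are hypothetical; each resulting dictionary would then be passed to `inference(**cfg)` as in the prediction code.\n",
    "\n",
    "```python\n",
    "def configs_for_texts(base_config, texts):\n",
    "    # Build a new dictionary per text so the base configuration is not mutated.\n",
    "    return [{**base_config, 'text': text} for text in texts]\n",
    "\n",
    "\n",
    "# Shortened stand-in for the parsed configuration (illustration only).\n",
    "base_config = {'text': '奔驰GLS怎么样？', 'num_qubits': 6, 'num_layers': 1}\n",
    "headlines = ['如何评价襄阳房价？', '保时捷为什么这么贵']\n",
    "\n",
    "for cfg in configs_for_texts(base_config, headlines):\n",
    "    print(cfg['text'], cfg['num_qubits'])\n",
    "```"
   ]
  },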
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Note\n",
    "\n",
    "Here, we provide models for text classification of news headlines in the automotive and real estate industries. Developers can also use their own datasets to train the corresponding models.\n",
    "\n",
    "### The structure of the dataset\n",
    "\n",
    "If you want to use a custom dataset for training, you just need to prepare the dataset according to the rules. Prepare `train.txt` and `test.txt` in the dataset folder, and `dev.txt` if a validation set is needed. One line is used to represent one piece of data in each file. Each line contains text and a corresponding label, separated by tabs. Text is composed of space-separated words.\n",
    "\n",
    "### Introduction to the Configuration File\n",
    "\n",
    "In `test.toml`, there is a complete reference to the configuration files needed for testing. In `train.toml`, there is a complete reference to the configuration files needed for training. Use `python qsann_classification --config train.toml` to train the model. Use `python qsann_classification --config test.toml` to load the trained model for testing.\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Citation\n",
    "\n",
    "```tex\n",
    "@article{li2022quantum,\n",
    "  title={Quantum Self-Attention Neural Networks for Text Classification},\n",
    "  author={Li, Guangxi and Zhao, Xuanqiang and Wang, Xin},\n",
    "  journal={arXiv preprint arXiv:2205.05625},\n",
    "  year={2022}\n",
    "}\n",
    "```\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.15 (default, Nov 10 2022, 12:46:26) \n[Clang 14.0.6 ]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "49b49097121cb1ab3a8a640b71467d7eda4aacc01fc9ff84d52fcb3bd4007bf1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
